{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "637adcb1",
   "metadata": {},
   "source": [
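    "## Survival Analysis\n",
    "\n",
    "The Kaplan-Meier (KM) estimator is typically used to fit a single variable; it can also be fitted separately per subgroup to compare survival across the levels of a variable."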
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f50c8eb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from onekey_algo import OnekeyDS as okds\n",
    "\n",
    "# Set plotting parameters.\n",
    "plt.rcParams['figure.figsize'] = (10.0, 8.0)\n",
    "plt.rcParams['font.sans-serif'] = 'SimHei'\n",
    "plt.rcParams['axes.unicode_minus'] = False\n",
    "\n",
    "mydir = okds.survival\n",
    "data = pd.read_csv(mydir)\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49fcae05",
   "metadata": {},
   "source": [
    "### KM Estimator\n",
    "To estimate the survival function, we first use the Kaplan-Meier estimate, defined as:\n",
    "\n",
    "$\\hat{S}(t) = \\prod_{t_i \\le t} \\frac{n_i - d_i}{n_i}$\n",
    "\n",
    "where $d_i$ is the number of death events at time $t_i$ and $n_i$ is the number of subjects at risk of death just prior to time $t_i$."
   ]
  },
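  {
   "cell_type": "markdown",
   "id": "km-worked-example",
   "metadata": {},
   "source": [
    "As a quick check of the formula above, here is a worked example on hypothetical toy data (not taken from the dataset loaded in this notebook). Assume five subjects with events observed at $t = 2, 3, 5$ and censoring at $t = 3$ and $t = 8$, and assume the usual convention that a subject censored at the same time as an event still counts as at risk for that event:\n",
    "\n",
    "$$\\hat{S}(6) = \\frac{5 - 1}{5} \\cdot \\frac{4 - 1}{4} \\cdot \\frac{2 - 1}{2} = 0.8 \\times 0.75 \\times 0.5 = 0.3$$\n",
    "\n",
    "The censored subjects never add a factor of their own; they only shrink the risk set $n_i$ for later event times."
   ]
  },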
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b69afca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lifelines import KaplanMeierFitter\n",
    "kmf = KaplanMeierFitter()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3bfcca4",
   "metadata": {},
   "source": [
    "### Define survival time and final status\n",
    "For this estimation, we need each subject's observed duration (the `duration` column) and whether the event was observed (the `result` column). Subjects whose event was not observed by the end of follow-up are treated as censored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d5b1bb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "T = data[\"duration\"]  # observed time for each subject\n",
    "E = data[\"result\"]  # event indicator passed as a 1-D Series (truthy = event observed)\n",
    "\n",
    "kmf.fit(T, event_observed=E)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "708c07d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt is already imported in the first code cell.\n",
    "kmf.plot_survival_function()\n",
    "plt.title('Overall survival function')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b81d911",
   "metadata": {},
   "outputs": [],
   "source": [
    "ax = plt.subplot(111)\n",
    "\n",
    "dem = (data[\"smoke\"] == 1)  # boolean mask selecting smokers\n",
    "\n",
    "kmf.fit(T[dem], event_observed=E[dem], label=\"Smoker\")\n",
    "kmf.plot_survival_function(ax=ax)\n",
    "\n",
    "kmf.fit(T[~dem], event_observed=E[~dem], label=\"Non-smoker\")\n",
    "kmf.plot_survival_function(ax=ax)\n",
    "\n",
    "plt.title(\"Survival by smoking status\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04b44640",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "ax = plt.subplot(111)\n",
    "\n",
    "t = np.linspace(20, 80, 41)  # shared timeline so both curves are evaluated at the same points\n",
    "kmf.fit(T[dem], event_observed=E[dem], timeline=t, label=\"Smoker\")\n",
    "ax = kmf.plot_survival_function(ax=ax)\n",
    "\n",
    "kmf.fit(T[~dem], event_observed=E[~dem], timeline=t, label=\"Non-smoker\")\n",
    "ax = kmf.plot_survival_function(ax=ax)\n",
    "\n",
    "plt.title(\"Survival by smoking status on a shared timeline\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a93d4c7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "stages = data['Tstage'].unique()\n",
    "\n",
    "# One panel per T stage; the 2x2 grid assumes at most four distinct stages.\n",
    "for i, stage in enumerate(stages):\n",
    "    ax = plt.subplot(2, 2, i + 1)\n",
    "\n",
    "    ix = data['Tstage'] == stage\n",
    "    kmf.fit(T[ix], E[ix], label=stage)\n",
    "    kmf.plot_survival_function(ax=ax, legend=False)\n",
    "\n",
    "    plt.title(stage)\n",
    "    plt.xlim(20, 83)\n",
    "\n",
    "    if i == 0:\n",
    "        plt.ylabel('Survival probability')\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20c9fd42",
   "metadata": {},
   "outputs": [],
   "source": [
    "kmf = KaplanMeierFitter().fit(T, E, label=\"all_subjects\")\n",
    "kmf.plot_survival_function(at_risk_counts=True)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf716b5b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
